RBONN: Recurrent Bilinear Optimization for a Binary Neural Network


Algorithm 7 RBONN training.
Input: a minibatch of inputs and their labels, real-valued weights $\mathbf{w}$, recurrent model weights $\mathbf{U}$, scaling factor matrix $\mathbf{A}$, and learning rates $\eta_1$, $\eta_2$, and $\eta_3$.
Output: updated real-valued weights $\mathbf{w}^{t+1}$, updated scaling factor matrix $\mathbf{A}^{t+1}$, and updated recurrent model weights $\mathbf{U}^{t+1}$.
1: while Forward propagation do
2:   $\mathbf{b}_{\mathbf{w}}^{t} \leftarrow \operatorname{sign}(\mathbf{w}^{t})$.
3:   $\mathbf{b}_{\mathbf{a}_{in}}^{t} \leftarrow \operatorname{sign}(\mathbf{a}_{in}^{t})$.
4:   Feature calculation using Eq. 6.36.
5:   Loss calculation using Eq. 6.68.
6: end while
7: while Backward propagation do
8:   Compute $\frac{\partial \mathcal{L}}{\partial \mathbf{A}^{t}}$, $\frac{\partial \mathcal{L}}{\partial \mathbf{w}^{t}}$, and $\frac{\partial \mathcal{L}}{\partial \mathbf{U}^{t}}$ using Eqs. 6.70, 6.72, and 3.136.
9:   Update $\mathbf{A}^{t+1}$, $\mathbf{w}^{t+1}$, and $\mathbf{U}^{t+1}$ according to Eqs. 6.69, 6.44, and 6.50, respectively.
10: end while
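To make Algorithm 7 concrete, the following PyTorch-style sketch shows one training iteration under stated assumptions: features_fn, loss_fn, and grad_U_fn are hypothetical stand-ins for the feature calculation (Eq. 6.36), the loss (Eq. 6.68), and the gradient of Eq. (3.136), and plain SGD updates replace whatever optimizer the actual implementation uses.

import torch

# Hypothetical sketch of one RBONN training iteration (Algorithm 7);
# not the reference implementation.
def rbonn_step(w, A, U, a_in, labels, features_fn, loss_fn, grad_U_fn,
               lr1, lr2, lr3):
    # Forward propagation: binarize weights and activations (lines 2-3).
    # The (sign(x) - x).detach() + x trick keeps gradients flowing to w and a_in.
    b_w = (torch.sign(w) - w).detach() + w
    b_a = (torch.sign(a_in) - a_in).detach() + a_in
    out = features_fn(b_w, b_a, A)            # feature calculation (cf. Eq. 6.36)
    loss = loss_fn(out, labels)               # loss calculation (cf. Eq. 6.68)

    # Backward propagation: gradients w.r.t. A and w (line 8).
    grad_A, grad_w = torch.autograd.grad(loss, (A, w))

    with torch.no_grad():
        grad_U = grad_U_fn(grad_w, w, A)      # dL/dU per Eq. (3.136), abstracted here
        A -= lr1 * grad_A                     # update A (line 9)
        w -= lr2 * grad_w                     # update w (line 9)
        U.copy_((U - lr3 * grad_U).abs())     # U^{t+1} = |U^t - eta_3 dL/dU|, Eq. (3.135)
    return loss.item()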

where $\mathbf{w} = \operatorname{diag}(\|\mathbf{w}_{1}\|_{1}, \cdots, \|\mathbf{w}_{C_{out}}\|_{1})$. We judge that asynchronous convergence occurs in optimization when $(\neg D(\mathbf{w}_{i}))\,D(\mathbf{A}_{i}) = 1$, where the density function is defined as
$$D(\mathbf{x}_{i}) = \begin{cases} 1 & \text{if } \operatorname{ranking}(\sigma(\mathbf{x})_{i}) > T, \\ 0 & \text{otherwise}, \end{cases} \qquad (3.134)$$
where $T$ is defined by $T = \operatorname{int}(C_{out}\times\tau)$, and $\tau$ is the hyperparameter that denotes the threshold. $\sigma(\mathbf{x})_{i}$ denotes the $i$-th eigenvalue of the diagonal matrix $\mathbf{x}$, and $\mathbf{x}_{i}$ denotes the $i$-th row of matrix $\mathbf{x}$ (a small code sketch of this convergence test is given after Eq. (3.136)). Finally, we define the optimization of $\mathbf{U}$ as

$$\mathbf{U}^{t+1} = \left|\,\mathbf{U}^{t} - \eta_{3}\frac{\partial \mathcal{L}}{\partial \mathbf{U}^{t}}\,\right|, \qquad (3.135)$$
$$\frac{\partial \mathcal{L}}{\partial \mathbf{U}^{t}} = \frac{\partial \mathcal{L}_{S}}{\partial \mathbf{w}^{t}}\, D_{\mathrm{ReLU}}(\mathbf{w}^{t-1}, \mathbf{A}^{t}), \qquad (3.136)$$

where $\eta_3$ is the learning rate of $\mathbf{U}$. We elaborate on the RBONN training process outlined in Algorithm 7.
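For illustration, the snippet below sketches the density indicator of Eq. (3.134) and the asynchronous-convergence test $(\neg D(\mathbf{w}_{i}))\,D(\mathbf{A}_{i}) = 1$; the descending-by-value ranking convention and the function names are our assumptions, not part of the original formulation.

import torch

# Hypothetical sketch of the density function D (Eq. 3.134) and the
# asynchronous-convergence mask; the ranking convention is an assumption.
def density(x_diag, tau):
    # x_diag: 1-D tensor holding the diagonal entries (eigenvalues) of a diagonal
    # matrix; returns 1 where the entry's rank exceeds T = int(C_out * tau).
    c_out = x_diag.numel()
    T = int(c_out * tau)
    order = torch.argsort(x_diag, descending=True)   # assumed: rank 1 = largest value
    rank = torch.empty_like(order)
    rank[order] = torch.arange(1, c_out + 1)
    return (rank > T).to(x_diag.dtype)

def async_convergence_mask(w_diag, A_diag, tau):
    # Rows where w already looks converged (D = 0) while A does not (D = 1).
    return (1.0 - density(w_diag, tau)) * density(A_diag, tau)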

3.8.3 Discussion

In this section, we first review related methods on “gradient approximation” for BNNs, then discuss the differences between RBONN and these methods, and analyze the effectiveness of the proposed RBONN.

In particular, BNN [99] directly utilizes the straight-through estimator in the training stage to calculate the gradients of the weights and activations as
$$\frac{\partial b_{w_{i,j}}}{\partial w_{i,j}} = 1_{|w_{i,j}|<1}, \qquad \frac{\partial b_{a_{i,j}}}{\partial a_{i,j}} = 1_{|a_{i,j}|<1}, \qquad (3.137)$$
which suffers from an obvious mismatch between this surrogate and the true gradient of the binarization function. Intuitively, Bi-Real Net [159] designs an approximate binarization function that can help alleviate the gradient mismatch in backward propagation as

$$\frac{\partial b_{a_{i,j}}}{\partial a_{i,j}} = \begin{cases} 2 + 2a_{i,j}, & -1 \le a_{i,j} < 0, \\ 2 - 2a_{i,j}, & 0 \le a_{i,j} < 1, \\ 0, & \text{otherwise}, \end{cases} \qquad (3.138)$$
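To contrast the two backward rules, the sketch below implements the straight-through estimator of Eq. (3.137) and the piecewise-polynomial gradient of Eq. (3.138) as custom autograd functions; the class names are ours and not taken from the original BNN or Bi-Real Net code.

import torch

class SignSTE(torch.autograd.Function):
    # Straight-through estimator (Eq. 3.137): pass the incoming gradient where |x| < 1.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * (x.abs() < 1).to(grad_out.dtype)

class SignApprox(torch.autograd.Function):
    # Bi-Real-Net-style approximation (Eq. 3.138): piecewise-polynomial gradient.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad = torch.zeros_like(x)
        grad = torch.where((x >= -1) & (x < 0), 2 + 2 * x, grad)
        grad = torch.where((x >= 0) & (x < 1), 2 - 2 * x, grad)
        return grad_out * grad

# Usage: binary_a = SignSTE.apply(a) or binary_a = SignApprox.apply(a)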